{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d70339c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.stats import wasserstein_distance\n",
    "import helpers as ph\n",
    "import seaborn as sns\n",
    "import dataframe_image as dfi\n",
    "\n",
    "# Keep a short module-level handle on the style definitions exposed by the helpers module.\n",
    "styles = ph.VIS_STYLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fdcd906",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the per-survey CSV files that the loading cell reads.\n",
    "RESULTS_DIR = './data/distributions/'\n",
    "# Context tag embedded in the '<survey>_<CONTEXT>_combined.csv' file names.\n",
    "CONTEXT = 'default'\n",
    "# When True, the styled refusal table is also exported as a PNG.\n",
    "SAVEFIG = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "271c7575",
   "metadata": {},
   "source": [
    "## Load human and LM opinion distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f60044",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather one frame per survey wave and concatenate once after the loop.\n",
    "combined_frames, human_frames = [], []\n",
    "for wave in ph.PEW_SURVEY_LIST:\n",
    "    SURVEY_NAME = f'American_Trends_Panel_W{wave}'\n",
    "    survey_label = f'ATP {wave}'\n",
    "\n",
    "    # Combined results file for the configured CONTEXT.\n",
    "    model_frame = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_combined.csv'))\n",
    "    model_frame['survey'] = survey_label\n",
    "    combined_frames.append(model_frame)\n",
    "\n",
    "    # Baseline file for the same wave.\n",
    "    human_frame = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_baseline.csv'))\n",
    "    human_frame['survey'] = survey_label\n",
    "    human_frames.append(human_frame)\n",
    "\n",
    "combined_df, human_df = pd.concat(combined_frames), pd.concat(human_frames)\n",
    "\n",
    "# Tag every row with a provider label: model names containing 'j1-'\n",
    "# (case-insensitive) map to 'AI21 Labs', all others to 'OpenAI'.\n",
    "# A vectorised string match replaces the row-wise apply, which is slow and\n",
    "# does not return a Series on an empty frame; missing names count as no match.\n",
    "is_ai21 = combined_df['model_name'].str.lower().str.contains('j1-', regex=False, na=False)\n",
    "combined_df['Source'] = np.where(is_ai21, 'AI21 Labs', 'OpenAI')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5242f3ae",
   "metadata": {},
   "source": [
    "## Compare refusals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a2e617",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grouping columns for the per-model refusal summary.\n",
    "KEYS = ['Source', 'model_name', 'attribute', 'model_order']\n",
    "\n",
    "\n",
    "def percent_mean(values):\n",
    "    \"\"\"Return the mean of ``values`` multiplied by 100.\"\"\"\n",
    "    return 100 * np.mean(values)\n",
    "\n",
    "\n",
    "# Model side: keep only the 'Overall' group, average R_M within each KEYS\n",
    "# combination, and order the rows by 'model_order'.\n",
    "overall_rows = combined_df[combined_df['group'] == 'Overall']\n",
    "model_refusals = (\n",
    "    overall_rows\n",
    "    .groupby(KEYS, as_index=False)\n",
    "    .agg({'R_M': percent_mean})\n",
    "    .sort_values(by=['model_order'])\n",
    "    .rename(columns={'R_M': 'Refusal'})\n",
    ")\n",
    "\n",
    "# Human side: average R_H within each ('group', 'group_order') pair over all\n",
    "# rows of combined_df, then attach fixed 'Source' / 'model_name' labels.\n",
    "human_refusals = (\n",
    "    combined_df\n",
    "    .groupby(['group', 'group_order'], as_index=False)\n",
    "    .agg({'R_H': percent_mean})\n",
    "    .rename(columns={'R_H': 'Refusal'})\n",
    "    .assign(Source='humans', model_name='overall')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d0cc9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the human 'Overall' row on top of the per-model rows. 'model_name' is\n",
    "# renamed to '' so the second level of the pivoted column header has no title.\n",
    "overall_humans = human_refusals[human_refusals['group'] == 'Overall']\n",
    "refusal_long = pd.concat([overall_humans, model_refusals]).rename(columns={'model_name': ''})\n",
    "\n",
    "# One column per (Source, model name) pair; sort=False asks pandas not to sort the result.\n",
    "refusal_table = pd.pivot_table(refusal_long,\n",
    "                               columns=['Source', ''],\n",
    "                               values='Refusal',\n",
    "                               sort=False)\n",
    "\n",
    "table_vis = (\n",
    "    refusal_table.style\n",
    "    .background_gradient('Reds', axis=1)\n",
    "    .set_table_styles(ph.VIS_STYLES)\n",
    "    .set_properties(**{'font-size': '0.7rem'})\n",
    "    .format(precision=3)\n",
    ")\n",
    "\n",
    "if SAVEFIG:\n",
    "    # Ensure the output folder exists before the PNG is written into it.\n",
    "    os.makedirs('./figures', exist_ok=True)\n",
    "    # Styler.hide_index() was deprecated in pandas 1.4 and removed in 2.0;\n",
    "    # prefer Styler.hide(axis='index') and fall back only on older pandas.\n",
    "    if hasattr(table_vis, 'hide'):\n",
    "        table_vis.hide(axis='index').export_png('./figures/refusals.png')\n",
    "    else:\n",
    "        table_vis.hide_index().export_png('./figures/refusals.png')\n",
    "display(table_vis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "794054c1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
